package main.java.exercise;

import main.java.framework.StudentInformation;
import main.java.framework.StudentSolution;
import main.java.framework.solver.Solver;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

/**
 * Selects bundles covering a set of articles by reducing the problem to SAT.
 *
 * <p>Each bundle becomes a boolean variable (bundle {@code j} is variable {@code j + 1}) and each
 * article becomes a clause listing the bundles that contain it. Keeping only a subset of bundles,
 * by blanking every other variable out of all clauses, yields a formula of positive literals only.
 * That formula is satisfiable exactly when the kept bundles cover every article. The smallest such
 * subset is located with a binary search over the subset size.
 */
public class StudentSolutionImplementation implements StudentSolution {
    @Override
    public StudentInformation provideStudentInformation() {
        return new StudentInformation(
                "", // first name (Vorname)
                "", // last name (Nachname)
                "" // matriculation number (Matrikelnummer)
        );
    }

    /**
     * Encodes the covering problem as a DIMACS CNF string.
     *
     * <p>The header is {@code p cnf <#bundles> <#articles>}. It is followed by one clause per
     * article, listing (1-based) every bundle that contains it and terminated by {@code 0}. Every
     * number on a clause line is preceded by a single space, which {@link #restrictTo} relies on to
     * match whole variable tokens. An article that no bundle contains yields the empty clause
     * {@code " 0"}, which makes the formula unsatisfiable.
     */
    private String constructDIMACS(int[] articles, int[][] bundles) {
        StringBuilder lines = new StringBuilder();
        // Append ints directly (not String.format) so digits are ASCII regardless of default locale.
        lines.append("p cnf ").append(bundles.length).append(' ').append(articles.length).append('\n');
        for (int article : articles) {
            for (int j = 0; j < bundles.length; j++) {
                if (containsArticle(bundles[j], article)) {
                    lines.append(' ').append(j + 1);
                }
            }
            lines.append(" 0\n");
        }
        return lines.toString();
    }

    /** Returns whether {@code bundle} lists {@code article}; each bundle is emitted at most once per clause. */
    private static boolean containsArticle(int[] bundle, int article) {
        for (int candidate : bundle) {
            if (candidate == article) {
                return true;
            }
        }
        return false;
    }

    /**
     * Interprets the solver's answer. A {@code null} or empty result, or one containing
     * {@code "error"}, counts as "not satisfiable"; anything else counts as satisfiable.
     * NOTE(review): this is a heuristic on the solver's output text; confirm it against the Solver
     * contract.
     */
    private boolean checkSolverResult(String result) {
        return result != null && !result.isEmpty() && !result.contains("error");
    }

    /**
     * Binary-searches the smallest size {@code k} in {@code [1, bundleCount]} for which some
     * {@code k}-subset of bundles keeps {@code formula} satisfiable.
     *
     * <p>Feasibility is monotone in {@code k}: adding a bundle to a cover keeps it a cover. That
     * makes binary search valid, and the {@code lo}/{@code hi} bounds guarantee that no size is
     * probed twice. Sizes start at 1, so a single bundle is still tested.
     *
     * @return the formula restricted to the lexicographically first accepted subset of minimal
     *         size, or {@code null} if no subset (not even all bundles) is accepted or there are
     *         no bundles
     */
    private String findMinimalCover(String formula, Solver solver, int bundleCount) {
        // The header is kept verbatim; only the clause lines are rewritten per subset.
        int newline = formula.indexOf('\n');
        String head = formula.substring(0, newline);
        String body = formula.substring(newline);

        int lo = 1;
        int hi = bundleCount;
        String best = null;
        while (lo <= hi) {
            int k = (lo + hi) >>> 1;
            String candidate = firstSatisfiableSubset(head, body, solver, bundleCount, k);
            if (candidate != null) {
                best = candidate; // k works; look for something smaller
                hi = k - 1;
            } else {
                lo = k + 1; // no k-subset works, so no smaller subset can either
            }
        }
        return best;
    }

    /**
     * Enumerates all {@code k}-subsets of {@code {0, ..., n-1}} in lexicographic order. Returns the
     * restricted formula of the first subset the solver accepts, or {@code null} if it accepts none.
     * Requires {@code 1 <= k <= n}.
     * Combination stepping follows https://www.baeldung.com/java-combinations-algorithm.
     */
    private String firstSatisfiableSubset(String head, String body, Solver solver, int n, int k) {
        int[] combination = new int[k];
        for (int i = 0; i < k; i++) {
            combination[i] = i;
        }
        do {
            String candidate = restrictTo(head, body, combination, n);
            if (checkSolverResult(solver.solve(candidate))) {
                return candidate;
            }
        } while (nextCombination(combination, n));
        return null;
    }

    /**
     * Rebuilds the formula keeping only the bundles in {@code kept} (sorted ascending, 0-based).
     * The variable of every other bundle is blanked out of all clauses.
     */
    private static String restrictTo(String head, String body, int[] kept, int n) {
        String clauses = body;
        int next = 0;
        for (int bundle = 0; bundle < n; bundle++) {
            if (next < kept.length && kept[next] == bundle) {
                next++;
                continue;
            }
            // The surrounding spaces make " 1 " match variable 1 but not 11 or 21. The blanks left
            // behind keep neighbouring tokens space-delimited. replace() is literal (no regex).
            clauses = clauses.replace(" " + (bundle + 1) + " ", "   ");
        }
        return head + clauses;
    }

    /**
     * Advances {@code combination}, a strictly increasing array of indices into {@code [0, n)},
     * to its lexicographic successor in place.
     *
     * @return {@code false} (leaving the array unchanged) if it was already the last combination
     */
    private static boolean nextCombination(int[] combination, int n) {
        int k = combination.length;
        int pos = k - 1;
        // Find the rightmost position that has not yet reached its maximum value n - k + pos.
        while (pos >= 0 && combination[pos] == n - k + pos) {
            pos--;
        }
        if (pos < 0) {
            return false;
        }
        combination[pos]++;
        for (int i = pos + 1; i < k; i++) {
            combination[i] = combination[i - 1] + 1;
        }
        return true;
    }

    /**
     * Selects a smallest set of bundles that together contain every article.
     *
     * @param articleIds     the articles that must be covered
     * @param articleBundles {@code articleBundles[j]} lists the article ids contained in bundle {@code j}
     * @param solver         solver queried with DIMACS CNF strings
     * @param chosenBundles  output: {@code chosenBundles[j]} is set to whether bundle {@code j} is
     *                       selected; assumed to have length {@code articleBundles.length}; left
     *                       untouched when no cover exists
     * @return {@code true} if a cover was found and at least one bundle was selected
     */
    public boolean findBundles(int[] articleIds, int[][] articleBundles, Solver solver, boolean[] chosenBundles) {
        String cover = findMinimalCover(
                constructDIMACS(articleIds, articleBundles),
                solver,
                articleBundles.length
        );
        if (cover == null) {
            return false;
        }

        // Only inspect the clause lines: the header would otherwise match e.g. " 3 " in "p cnf 3 5".
        String clauses = cover.substring(cover.indexOf('\n'));
        boolean anyChosen = false;
        for (int j = 0; j < articleBundles.length; j++) {
            chosenBundles[j] = clauses.contains(" " + (j + 1) + " ");
            anyChosen |= chosenBundles[j];
        }
        return anyChosen;
    }
}
